{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "u3B7Uh50lozN"
      },
      "outputs": [],
      "source": [
        "!pip install -U -q tf-nightly"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "qWUV0FYjDSKj"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tensorflow.contrib import autograph\n",
        "\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kGXS3UWBBNoc"
      },
      "source": [
        "# 1. AutoGraph writes graph code for you\n",
        "\n",
        "[AutoGraph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/README.md) helps you write complicated graph code using plain Python: behind the scenes, AutoGraph automatically transforms your code into equivalent TensorFlow graph code. We support a large and growing subset of the Python language. [See this document for what we currently support and what we're working on](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/LIMITATIONS.md).\n",
        "\n",
        "Here's a quick example of how it works:\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "aA3gOodCBkOw"
      },
      "outputs": [],
      "source": [
        "# AutoGraph can convert functions like this...\n",
        "def g(x):\n",
        "  if x \u003e 0:\n",
        "    x = x * x\n",
        "  else:\n",
        "    x = 0.0\n",
        "  return x\n",
        "\n",
        "# ...into graph-building functions like this:\n",
        "def tf_g(x):\n",
        "  with tf.name_scope('g'):\n",
        "\n",
        "    def if_true():\n",
        "      with tf.name_scope('if_true'):\n",
        "        x_1, = x,\n",
        "        x_1 = x_1 * x_1\n",
        "        return x_1,\n",
        "\n",
        "    def if_false():\n",
        "      with tf.name_scope('if_false'):\n",
        "        x_1, = x,\n",
        "        x_1 = 0.0\n",
        "        return x_1,\n",
        "\n",
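        "    # Simplified: the real generated code routes this through an AutoGraph\n",
        "    # helper, which builds a tf.cond when the condition is a tensor.\n",
        "    x = tf.cond(tf.greater(x, 0), if_true, if_false)\n",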
        "    return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "I1RtBvoKBxq5"
      },
      "outputs": [],
      "source": [
        "# You can run your plain-Python code in graph mode\n",
        "# and get the same results, with all the benefits of graphs:\n",
        "print('Original value: %2.2f' % g(9.0))\n",
        "\n",
        "# Generate a graph-version of g and call it:\n",
        "tf_g = autograph.to_graph(g)\n",
        "\n",
        "with tf.Graph().as_default():\n",
        "  # The result works like a regular op: takes tensors in, returns tensors.\n",
        "  # You can inspect the graph using tf.get_default_graph().as_graph_def()\n",
        "  g_ops = tf_g(tf.constant(9.0))\n",
        "  with tf.Session() as sess:\n",
        "    print('AutoGraph value: %2.2f\\n' % sess.run(g_ops))\n",
        "\n",
        "\n",
        "# You can view, debug and tweak the generated code:\n",
        "print(autograph.to_code(g))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "m-jWmsCmByyw"
      },
      "source": [
        "#### Automatically converting complex control flow\n",
        "\n",
        "AutoGraph can convert a large subset of the Python language into equivalent graph-construction code, and we're adding support for new language features all the time. This section gives you a taste of what it can do.\n",
        "\n",
        "AutoGraph automatically converts most Python control flow statements into their graph equivalents. We support common statements such as `while`, `for`, `if`, `break`, `continue`, and `return`, and you can nest them as deeply as you like. Imagine trying to write the graph version of this code by hand:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "toxKBOXbB1ro"
      },
      "outputs": [],
      "source": [
        "# Continue in a loop\n",
        "def f(l):\n",
        "  s = 0\n",
        "  for c in l:\n",
        "    if c % 2 \u003e 0:\n",
        "      continue\n",
        "    s += c\n",
        "  return s\n",
        "\n",
        "print('Original value: %d' % f([10,12,15,20]))\n",
        "\n",
        "tf_f = autograph.to_graph(f)\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session():\n",
        "    print('Graph value: %d\\n\\n' % tf_f(tf.constant([10,12,15,20])).eval())\n",
        "\n",
        "print(autograph.to_code(f))"
      ]
    },
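    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To check the graph value above, here is the same loop traced in plain Python. This is a minimal sketch: `f_traced` is a hypothetical helper, not part of AutoGraph, and it only adds `print` calls so you can see which elements `continue` skips and how the sum builds up."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Plain-Python trace of f above (no TensorFlow involved).\n",
        "# f_traced is a hypothetical helper used only for illustration.\n",
        "def f_traced(l):\n",
        "  s = 0\n",
        "  for c in l:\n",
        "    if c % 2 > 0:\n",
        "      print('skip %d' % c)  # odd values hit `continue`\n",
        "      continue\n",
        "    s += c\n",
        "    print('add %d -> s = %d' % (c, s))\n",
        "  return s\n",
        "\n",
        "print('Result: %d' % f_traced([10, 12, 15, 20]))"
      ]
    },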
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "FUJJ-WTdCGeq"
      },
      "source": [
        "Try replacing the `continue` in the above code with `break` -- AutoGraph supports that as well!  \n",
        "  \n",
        "Let's try some other useful Python constructs, like `print` and `assert`. We automatically convert Python `assert` statements into the equivalent `tf.Assert` code.  "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "IAOgh62zCPZ4"
      },
      "outputs": [],
      "source": [
        "def f(x):\n",
        "  assert x != 0, 'Do not pass zero!'\n",
        "  return x * x\n",
        "\n",
        "tf_f = autograph.to_graph(f)\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session():\n",
        "    try:\n",
        "      print(tf_f(tf.constant(0)).eval())\n",
        "    except tf.errors.InvalidArgumentError as e:\n",
        "      print('Got error message:\\n%s' % e.message)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "KRu8iIPBCQr5"
      },
      "source": [
        "You can also call plain Python `print` inside converted functions; the values are printed when the graph runs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "ySTsuxnqCTQi"
      },
      "outputs": [],
      "source": [
        "def f(n):\n",
        "  if n \u003e= 0:\n",
        "    while n \u003c 5:\n",
        "      n += 1\n",
        "      print(n)\n",
        "  return n\n",
        "\n",
        "tf_f = autograph.to_graph(f)\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session():\n",
        "    tf_f(tf.constant(0)).eval()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NqF0GT-VCVFh"
      },
      "source": [
        "Appending to lists in loops also works (AutoGraph creates tensor list ops behind the scenes)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "ABX070KwCczR"
      },
      "outputs": [],
      "source": [
        "def f(n):\n",
        "  z = []\n",
        "  # We ask you to tell us the element dtype of the list\n",
        "  autograph.set_element_type(z, tf.int32)\n",
        "  for i in range(n):\n",
        "    z.append(i)\n",
        "  # when you're done with the list, stack it\n",
        "  # (this is just like np.stack)\n",
        "  return autograph.stack(z)\n",
        "\n",
        "tf_f = autograph.to_graph(f)\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session():\n",
        "    print(tf_f(tf.constant(3)).eval())\n",
        "\n",
        "print('\\n\\n'+autograph.to_code(f))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "iu5IF7n2Df7C"
      },
      "outputs": [],
      "source": [
        "def fizzbuzz(num):\n",
        "  if num % 3 == 0 and num % 5 == 0:\n",
        "    print('FizzBuzz')\n",
        "  elif num % 3 == 0:\n",
        "    print('Fizz')\n",
        "  elif num % 5 == 0:\n",
        "    print('Buzz')\n",
        "  else:\n",
        "    print(num)\n",
        "  return num"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "EExAjWuwDPpR"
      },
      "outputs": [],
      "source": [
        "tf_fizzbuzz = autograph.to_graph(fizzbuzz)\n",
        "\n",
        "with tf.Graph().as_default():\n",
        "  # The result works like a regular op: takes tensors in, returns tensors.\n",
        "  # You can inspect the graph using tf.get_default_graph().as_graph_def()\n",
        "  fizzbuzz_op = tf_fizzbuzz(tf.constant(15))\n",
        "  with tf.Session() as sess:\n",
        "    sess.run(fizzbuzz_op)\n",
        "\n",
        "# You can view, debug and tweak the generated code:\n",
        "print('\\n')\n",
        "print(autograph.to_code(fizzbuzz))"
      ]
    },
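    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The graph above runs `fizzbuzz` on a single input, 15. As a plain-Python reference, here is a sketch that walks 1 through 15 so every branch of the `if`/`elif` chain is taken; `fizzbuzz_word` is a hypothetical helper that returns the label instead of printing it, which makes it easy to check."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Plain-Python reference for the fizzbuzz branches (no TensorFlow).\n",
        "# fizzbuzz_word is a hypothetical helper that returns the label.\n",
        "def fizzbuzz_word(num):\n",
        "  if num % 3 == 0 and num % 5 == 0:\n",
        "    return 'FizzBuzz'\n",
        "  elif num % 3 == 0:\n",
        "    return 'Fizz'\n",
        "  elif num % 5 == 0:\n",
        "    return 'Buzz'\n",
        "  else:\n",
        "    return str(num)\n",
        "\n",
        "print(' '.join(fizzbuzz_word(n) for n in range(1, 16)))"
      ]
    },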
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "SzpKGzVpBkph"
      },
      "source": [
        "# 2. De-graphify exercises\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8k23dxcSmmXq"
      },
      "source": [
        "#### Easy print statements"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "dE1Vsmp-mlpK"
      },
      "outputs": [],
      "source": [
        "# The decorator below is commented out, so AutoGraph is off.\n",
        "# Do you see the type or the value of x when you print it?\n",
        "# Then uncomment the decorator to turn AutoGraph on and compare.\n",
        "\n",
        "# @autograph.convert()\n",
        "def square_log(x):\n",
        "  x = x * x\n",
        "  print('Squared value of x =', x)\n",
        "  return x\n",
        "\n",
        "\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session() as sess:\n",
        "    print(sess.run(square_log(tf.constant(4))))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_R-Q7BbxmkBF"
      },
      "source": [
        "#### Convert the TensorFlow code into Python code for AutoGraph"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "SwA11tO-yCvg"
      },
      "outputs": [],
      "source": [
        "def square_if_positive(x):\n",
        "  x = tf.cond(tf.greater(x, 0), lambda: x * x, lambda: x)\n",
        "  return x\n",
        "\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session() as sess:\n",
        "    print(sess.run(square_if_positive(tf.constant(4))))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "GPmx4CNhyPI_"
      },
      "outputs": [],
      "source": [
        "@autograph.convert()\n",
        "def square_if_positive(x):\n",
        "\n",
        "  pass # TODO: fill it in!\n",
        "\n",
        "\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session() as sess:\n",
        "    print(sess.run(square_if_positive(tf.constant(4))))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qqsjik-QyA9R"
      },
      "source": [
        "#### Expand to see the answer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "DaSmaWUEvMRv"
      },
      "outputs": [],
      "source": [
        "# Simple cond\n",
        "@autograph.convert()\n",
        "def square_if_positive(x):\n",
        "  if x \u003e 0:\n",
        "    x = x * x\n",
        "  return x\n",
        "\n",
        "with tf.Graph().as_default():  \n",
        "  with tf.Session() as sess:\n",
        "    print(sess.run(square_if_positive(tf.constant(4))))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qj7am2I_xvTJ"
      },
      "source": [
        "#### Nested if statement"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "4yyNOf-Twr6s"
      },
      "outputs": [],
      "source": [
        "def nearest_odd_square(x):\n",
        "\n",
        "  def if_positive():\n",
        "    x1 = x * x\n",
        "    x1 = tf.cond(tf.equal(x1 % 2, 0), lambda: x1 + 1, lambda: x1)\n",
        "    return x1\n",
        "\n",
        "  x = tf.cond(tf.greater(x, 0), if_positive, lambda: x)\n",
        "  return x\n",
        "\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session() as sess:\n",
        "    print(sess.run(nearest_odd_square(tf.constant(4))))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "hqmh5b2VyU9w"
      },
      "outputs": [],
      "source": [
        "@autograph.convert()\n",
        "def nearest_odd_square(x):\n",
        "\n",
        "  pass # TODO: fill it in!\n",
        "\n",
        "\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session() as sess:\n",
        "    print(sess.run(nearest_odd_square(tf.constant(4))))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "b9AXIkNLxp6J"
      },
      "source": [
        "#### Expand to see the answer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "8RlCVEpNxD91"
      },
      "outputs": [],
      "source": [
        "@autograph.convert()\n",
        "def nearest_odd_square(x):\n",
        "  if x \u003e 0:\n",
        "    x = x * x\n",
        "    if x % 2 == 0:\n",
        "      x = x + 1\n",
        "  return x\n",
        "\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session() as sess:\n",
        "    print(sess.run(nearest_odd_square(tf.constant(4))))"
      ]
    },
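    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To check the answer, here is a plain-Python sketch of the same logic (`nearest_odd_square_py` is a hypothetical name). For the input 4, squaring gives 16, which is even, so the inner branch adds 1 and the graph should print 17. An odd square is kept as-is, and a non-positive input is returned unchanged."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Plain-Python version of nearest_odd_square (no TensorFlow).\n",
        "def nearest_odd_square_py(x):\n",
        "  if x > 0:\n",
        "    x = x * x\n",
        "    if x % 2 == 0:\n",
        "      x = x + 1\n",
        "  return x\n",
        "\n",
        "for value in [4, 3, -2]:\n",
        "  print('%d -> %d' % (value, nearest_odd_square_py(value)))"
      ]
    },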
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jXAxjeBr1qWK"
      },
      "source": [
        "#### Convert a while loop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "kWkv7anlxoee"
      },
      "outputs": [],
      "source": [
        "# Convert a while loop\n",
        "def square_until_stop(x, y):\n",
        "  x = tf.while_loop(lambda x: tf.less(x, y), lambda x: x * x, [x])\n",
        "  return x\n",
        "\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session() as sess:\n",
        "    print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "zVUsc1eA1u2K"
      },
      "outputs": [],
      "source": [
        "@autograph.convert()\n",
        "def square_until_stop(x, y):\n",
        "\n",
        "  pass # TODO: fill it in!\n",
        "\n",
        "\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session() as sess:\n",
        "    print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "L2psuzPI02S9"
      },
      "source": [
        "#### Expand to see the answer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "ucmZyQVL03bF"
      },
      "outputs": [],
      "source": [
        "@autograph.convert()\n",
        "def square_until_stop(x, y):\n",
        "  while x \u003c y:\n",
        "    x = x * x\n",
        "  return x\n",
        "\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session() as sess:\n",
        "    print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))"
      ]
    },
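    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The `while` condition is checked before each squaring, so the result can overshoot `y`. The plain-Python sketch below (`square_until_stop_py` is a hypothetical name) traces the answer for `x = 4` and `y = 100`: 4 becomes 16, then 256, and only then does `x < y` fail, which is why the graph prints 256. Note that an input of 0 or 1 never grows, so with a larger `y` the loop would never end, in the graph as well as in Python."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Plain-Python trace of square_until_stop (no TensorFlow).\n",
        "def square_until_stop_py(x, y):\n",
        "  while x < y:\n",
        "    print('%d < %d, squaring' % (x, y))\n",
        "    x = x * x\n",
        "  return x\n",
        "\n",
        "print('Result: %d' % square_until_stop_py(4, 100))"
      ]
    },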
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "FXB0Zbwl13PY"
      },
      "source": [
        "#### Nested loop and conditional"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "clGymxdf15Ig"
      },
      "outputs": [],
      "source": [
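        "def argwhere_cumsum(x, threshold):\n",
        "  # Loop state: step i, last visited index, running sum, stop flag.\n",
        "  def cond(i, unused_idx, unused_sum, done):\n",
        "    return tf.logical_and(tf.less(i, tf.shape(x)[0]), tf.logical_not(done))\n",
        "\n",
        "  def body(i, unused_idx, current_sum, unused_done):\n",
        "    # Like `break`: stop once the sum reaches the threshold, but only\n",
        "    # after recording i as the current index.\n",
        "    done = tf.greater_equal(current_sum, threshold)\n",
        "    current_sum = tf.where(done, current_sum, current_sum + x[i])\n",
        "    return i + 1, i, current_sum, done\n",
        "\n",
        "  _, idx, _, _ = tf.while_loop(\n",
        "      cond, body,\n",
        "      [tf.constant(0), tf.constant(0), tf.constant(0.0), tf.constant(False)])\n",
        "  return idx\n",
        "\n",
        "n = 10\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session() as sess:\n",
        "    idx = argwhere_cumsum(tf.ones(n), tf.constant(float(n / 2)))\n",
        "    print(sess.run(idx))"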
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "i7PF-uId9lp5"
      },
      "outputs": [],
      "source": [
        "@autograph.convert()\n",
        "def argwhere_cumsum(x, threshold):\n",
        "\n",
        "  pass # TODO: fill it in!\n",
        "\n",
        "\n",
        "n = 10\n",
        "with tf.Graph().as_default():\n",
        "  with tf.Session() as sess:\n",
        "    idx = argwhere_cumsum(tf.ones(n), tf.constant(float(n / 2)))\n",
        "    print(sess.run(idx))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "weKFXAb615Vp"
      },
      "source": [
        "#### Expand to see the answer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "1sjaFcL717Ig"
      },
      "outputs": [],
      "source": [
        "@autograph.convert()\n",
        "def argwhere_cumsum(x, threshold):\n",
        "  current_sum = 0.0\n",
        "  idx = 0\n",
        "  for i in range(len(x)):\n",
        "    idx = i\n",
        "    if current_sum \u003e= threshold:\n",
        "      break\n",
        "    current_sum += x[i]\n",
        "  return idx\n",
        "\n",
        "n = 10\n",
        "with tf.Graph().as_default():  \n",
        "  with tf.Session() as sess:\n",
        "    idx = argwhere_cumsum(tf.ones(n), tf.constant(float(n / 2)))\n",
        "    print(sess.run(idx))"
      ]
    },
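    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Because `idx = i` runs before the `break` check, the function returns the first index at which the running sum of the earlier elements has already reached the threshold, or the last index if it never does. The plain-Python sketch below (`argwhere_cumsum_py` is a hypothetical name) runs the same logic on the inputs used above, ten ones and a threshold of 5.0, and returns 5."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Plain-Python version of argwhere_cumsum (no TensorFlow).\n",
        "def argwhere_cumsum_py(x, threshold):\n",
        "  current_sum = 0.0\n",
        "  idx = 0\n",
        "  for i in range(len(x)):\n",
        "    idx = i\n",
        "    if current_sum >= threshold:\n",
        "      break\n",
        "    current_sum += x[i]\n",
        "  return idx\n",
        "\n",
        "n = 10\n",
        "print(argwhere_cumsum_py([1.0] * n, float(n / 2)))  # same inputs as above\n",
        "print(argwhere_cumsum_py([1.0] * 3, 100.0))  # threshold never reached"
      ]
    },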
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "4LfnJjm0Bm0B"
      },
      "source": [
        "# 3. Training MNIST in-graph\n",
        "\n",
        "Writing control flow in AutoGraph is easy, so running a training loop in a TensorFlow graph should be easy as well!  \n",
        "\n",
        "Here, we show an example of training a simple Keras model on MNIST in which the entire training process runs in-graph: loading batches, calculating gradients, updating parameters, calculating test accuracy, and repeating for a fixed number of steps."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Em5dzSUOtLRP"
      },
      "source": [
        "#### Download data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "xqoxumv0ssQW"
      },
      "outputs": [],
      "source": [
        "import gzip\n",
        "import os\n",
        "import shutil\n",
        "\n",
        "from six.moves import urllib\n",
        "\n",
        "\n",
        "def download(directory, filename):\n",
        "  filepath = os.path.join(directory, filename)\n",
        "  if tf.gfile.Exists(filepath):\n",
        "    return filepath\n",
        "  if not tf.gfile.Exists(directory):\n",
        "    tf.gfile.MakeDirs(directory)\n",
        "  url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n",
        "  zipped_filepath = filepath + '.gz'\n",
        "  print('Downloading %s to %s' % (url, zipped_filepath))\n",
        "  urllib.request.urlretrieve(url, zipped_filepath)\n",
        "  with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n",
        "    shutil.copyfileobj(f_in, f_out)\n",
        "  os.remove(zipped_filepath)\n",
        "  return filepath\n",
        "\n",
        "\n",
        "def dataset(directory, images_file, labels_file):\n",
        "  images_file = download(directory, images_file)\n",
        "  labels_file = download(directory, labels_file)\n",
        "\n",
        "  def decode_image(image):\n",
        "    # Normalize from [0, 255] to [0.0, 1.0]\n",
        "    image = tf.decode_raw(image, tf.uint8)\n",
        "    image = tf.cast(image, tf.float32)\n",
        "    image = tf.reshape(image, [784])\n",
        "    return image / 255.0\n",
        "\n",
        "  def decode_label(label):\n",
        "    label = tf.decode_raw(label, tf.uint8)\n",
        "    label = tf.reshape(label, [])\n",
        "    return tf.to_int32(label)\n",
        "\n",
        "  images = tf.data.FixedLengthRecordDataset(\n",
        "      images_file, 28 * 28, header_bytes=16).map(decode_image)\n",
        "  labels = tf.data.FixedLengthRecordDataset(\n",
        "      labels_file, 1, header_bytes=8).map(decode_label)\n",
        "  return tf.data.Dataset.zip((images, labels))\n",
        "\n",
        "\n",
        "def mnist_train(directory):\n",
        "  return dataset(directory, 'train-images-idx3-ubyte',\n",
        "                 'train-labels-idx1-ubyte')\n",
        "\n",
        "def mnist_test(directory):\n",
        "  return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')"
      ]
    },
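    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The `header_bytes` values above follow the MNIST IDX file layout: an image file starts with four big-endian 32-bit integers (magic number 2051, number of images, rows, columns), 16 bytes in total, and a label file starts with two (magic number 2049, number of labels), 8 bytes in total. After the header, each image is 28 * 28 = 784 unsigned bytes and each label is a single byte, which matches the record lengths `28 * 28` and `1`. The sketch below packs a fake header with Python's `struct` module and parses it back; `parse_idx_header` is a hypothetical helper that the notebook itself does not use."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import struct\n",
        "\n",
        "# Hypothetical helper: parse an IDX header from raw bytes.\n",
        "# '>' means big-endian; 'I' is an unsigned 32-bit integer.\n",
        "def parse_idx_header(data, is_images):\n",
        "  fmt = '>IIII' if is_images else '>II'\n",
        "  size = struct.calcsize(fmt)\n",
        "  return struct.unpack(fmt, data[:size]), size\n",
        "\n",
        "# Fake headers shaped like the MNIST training files.\n",
        "fake_images = struct.pack('>IIII', 2051, 60000, 28, 28)\n",
        "fake_labels = struct.pack('>II', 2049, 60000)\n",
        "\n",
        "fields, size = parse_idx_header(fake_images, is_images=True)\n",
        "print('images: %s, header_bytes=%d' % (fields, size))\n",
        "fields, size = parse_idx_header(fake_labels, is_images=False)\n",
        "print('labels: %s, header_bytes=%d' % (fields, size))"
      ]
    },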
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "znmy4l8ntMvW"
      },
      "source": [
        "#### Define the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "Pe-erWQdBoC5"
      },
      "outputs": [],
      "source": [
        "def mlp_model(input_shape):\n",
        "  model = tf.keras.Sequential((\n",
        "      tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n",
        "      tf.keras.layers.Dense(100, activation='relu'),\n",
        "      tf.keras.layers.Dense(10, activation='softmax')))\n",
        "  model.build()\n",
        "  return model\n",
        "\n",
        "\n",
        "def predict(m, x, y):\n",
        "  y_p = m(x)\n",
        "  losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n",
        "  l = tf.reduce_mean(losses)\n",
        "  accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n",
        "  accuracy = tf.reduce_mean(accuracies)\n",
        "  return l, accuracy\n",
        "\n",
        "\n",
        "def fit(m, x, y, opt):\n",
        "  l, accuracy = predict(m, x, y)\n",
        "  opt.minimize(l)\n",
        "  return l, accuracy\n",
        "\n",
        "\n",
        "def setup_mnist_data(is_training, hp, batch_size):\n",
        "  if is_training:\n",
        "    ds = mnist_train('/tmp/autograph_mnist_data')\n",
        "    ds = ds.shuffle(batch_size * 10)\n",
        "  else:\n",
        "    ds = mnist_test('/tmp/autograph_mnist_data')\n",
        "  ds = ds.repeat()\n",
        "  ds = ds.batch(batch_size)\n",
        "  return ds\n",
        "\n",
        "\n",
        "def get_next_batch(ds):\n",
        "  itr = ds.make_one_shot_iterator()\n",
        "  image, label = itr.get_next()\n",
        "  x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n",
        "  y = tf.one_hot(tf.squeeze(label), 10)\n",
        "  return x, y"
      ]
    },
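    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "`predict` averages two per-example quantities over the batch. For a one-hot label `y` and softmax output `y_p`, categorical cross-entropy is `-sum(y * log(y_p))`, and categorical accuracy is 1 when the argmax of `y_p` matches the argmax of `y`. Here is a small plain-Python sketch of that math on a toy batch. It ignores details such as the renormalization and epsilon clipping Keras applies to the probabilities, and the helper names are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def crossentropy(y, y_p):\n",
        "  # Terms where the one-hot label is 0 contribute nothing, so skip them.\n",
        "  return -sum(t * math.log(p) for t, p in zip(y, y_p) if t > 0)\n",
        "\n",
        "def argmax(v):\n",
        "  return max(range(len(v)), key=lambda k: v[k])\n",
        "\n",
        "def accuracy(y, y_p):\n",
        "  return 1.0 if argmax(y) == argmax(y_p) else 0.0\n",
        "\n",
        "# A toy batch of two examples with three classes.\n",
        "ys = [[0, 1, 0], [1, 0, 0]]\n",
        "y_ps = [[0.1, 0.8, 0.1], [0.3, 0.6, 0.1]]\n",
        "losses = [crossentropy(y, p) for y, p in zip(ys, y_ps)]\n",
        "accs = [accuracy(y, p) for y, p in zip(ys, y_ps)]\n",
        "print('mean loss: %.4f' % (sum(losses) / len(losses)))\n",
        "print('mean accuracy: %.2f' % (sum(accs) / len(accs)))"
      ]
    },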
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "oeYV6mKnJGMr"
      },
      "source": [
        "#### Define the training loop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "3xtg_MMhJETd"
      },
      "outputs": [],
      "source": [
        "def train(train_ds, test_ds, hp):\n",
        "  m = mlp_model((28 * 28,))\n",
        "  opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n",
        "\n",
        "  # We'd like to save our losses to a list. In order for AutoGraph\n",
        "  # to convert these lists into their graph equivalent,\n",
        "  # we need to specify the element type of the lists.\n",
        "  train_losses = []\n",
        "  test_losses = []\n",
        "  train_accuracies = []\n",
        "  test_accuracies = []\n",
        "  autograph.set_element_type(train_losses, tf.float32)\n",
        "  autograph.set_element_type(test_losses, tf.float32)\n",
        "  autograph.set_element_type(train_accuracies, tf.float32)\n",
        "  autograph.set_element_type(test_accuracies, tf.float32)\n",
        "\n",
        "  # This entire training loop will be run in-graph.\n",
        "  i = tf.constant(0)\n",
        "  while i \u003c hp.max_steps:\n",
        "    train_x, train_y = get_next_batch(train_ds)\n",
        "    test_x, test_y = get_next_batch(test_ds)\n",
        "\n",
        "    step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n",
        "    step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n",
        "\n",
        "    if i % (hp.max_steps // 10) == 0:\n",
        "      print('Step', i, 'train loss:', step_train_loss, 'test loss:',\n",
        "            step_test_loss, 'train accuracy:', step_train_accuracy,\n",
        "            'test accuracy:', step_test_accuracy)\n",
        "\n",
        "    train_losses.append(step_train_loss)\n",
        "    test_losses.append(step_test_loss)\n",
        "    train_accuracies.append(step_train_accuracy)\n",
        "    test_accuracies.append(step_test_accuracy)\n",
        "\n",
        "    i += 1\n",
        "\n",
        "  # We've recorded our loss values and accuracies\n",
        "  # to a list in a graph with AutoGraph's help.\n",
        "  # In order to return the values as a Tensor,\n",
        "  # we need to stack them before returning them.\n",
        "  return (\n",
        "      autograph.stack(train_losses),\n",
        "      autograph.stack(test_losses),\n",
        "      autograph.stack(train_accuracies),\n",
        "      autograph.stack(test_accuracies),\n",
        "  )"
      ]
    },
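    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The logging test `i % (hp.max_steps // 10) == 0` prints about ten times per run: with `max_steps=500` the interval is 50, so steps 0, 50, ..., 450 are logged. If `max_steps` were below 10, though, `max_steps // 10` would be 0 and the modulo would divide by zero. The plain-Python sketch below shows the schedule; `logged_steps` is a hypothetical helper, and the `max(1, ...)` guard is one possible fix, not part of the training code above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Hypothetical helper mirroring the logging test in train().\n",
        "def logged_steps(max_steps, guard=False):\n",
        "  interval = max_steps // 10\n",
        "  if guard:\n",
        "    interval = max(1, interval)  # one possible fix for max_steps < 10\n",
        "  return [i for i in range(max_steps) if i % interval == 0]\n",
        "\n",
        "print(logged_steps(500))  # the schedule used in this notebook\n",
        "print(logged_steps(5, guard=True))\n",
        "try:\n",
        "  logged_steps(5)\n",
        "except ZeroDivisionError as e:\n",
        "  print('max_steps=5 without the guard: %s' % e)"
      ]
    },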
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "HYh6MSZyJOag"
      },
      "outputs": [],
      "source": [
        "with tf.Graph().as_default():\n",
        "  hp = tf.contrib.training.HParams(\n",
        "      learning_rate=0.05,\n",
        "      max_steps=500,\n",
        "  )\n",
        "  train_ds = setup_mnist_data(True, hp, 50)\n",
        "  test_ds = setup_mnist_data(False, hp, 1000)\n",
        "  tf_train = autograph.to_graph(train)\n",
        "  loss_tensors = tf_train(train_ds, test_ds, hp)\n",
        "\n",
        "  with tf.Session() as sess:\n",
        "    sess.run(tf.global_variables_initializer())\n",
        "    (\n",
        "        train_losses,\n",
        "        test_losses,\n",
        "        train_accuracies,\n",
        "        test_accuracies\n",
        "    ) = sess.run(loss_tensors)\n",
        "\n",
        "    plt.title('MNIST train/test losses')\n",
        "    plt.plot(train_losses, label='train loss')\n",
        "    plt.plot(test_losses, label='test loss')\n",
        "    plt.legend()\n",
        "    plt.xlabel('Training step')\n",
        "    plt.ylabel('Loss')\n",
        "    plt.show()\n",
        "    plt.title('MNIST train/test accuracies')\n",
        "    plt.plot(train_accuracies, label='train accuracy')\n",
        "    plt.plot(test_accuracies, label='test accuracy')\n",
        "    plt.legend(loc='lower right')\n",
        "    plt.xlabel('Training step')\n",
        "    plt.ylabel('Accuracy')\n",
        "    plt.show()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "qqsjik-QyA9R",
        "b9AXIkNLxp6J",
        "L2psuzPI02S9",
        "weKFXAb615Vp",
        "Em5dzSUOtLRP"
      ],
      "default_view": {},
      "name": "AutoGraph Workshop.ipynb",
      "provenance": [
        {
          "file_id": "1kE2gz_zuwdYySL4K2HQSz13uLCYi-fYP",
          "timestamp": 1530563781803
        }
      ],
      "version": "0.3.2",
      "views": {}
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
